{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "inclusive-adjustment",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "from sagemaker.pytorch import PyTorch\n",
    "from utils import wait_till_all_done, CLUSTER_DATASETS\n",
    "\n",
    "role = 'arn:aws:iam::157264205850:role/dejiao-sagemaker-run'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "guilty-fifty",
   "metadata": {},
   "outputs": [],
   "source": [
    "bert_models = [\"distilroberta\", \"distilbert\"]\n",
    "lr_params = [(1e-05, 100), (3e-05, 100)]\n",
    "contrast_types = [\"HardNeg\", \"Orig\"]\n",
    "temps = [0.5, 0.1]\n",
    "objectives = [\"contrastive\", \"SCCL\"]\n",
    "datasets = [\"agnews\", \"searchsnippets\", \"stackoverflow\", \"biomedical\", \"tweet\", \"googleT\", \"googleS\", \"googleTS\"]\n",
    "\n",
    "use_pretrain=\"SBERT\"\n",
    "augtype=\"virtual\"\n",
    "batch_size = 512\n",
    "maxlen = 32\n",
    "maxiter = 1000\n",
    "eta = 10\n",
    "base_job_name = \"SCCLv2-distil-hpo\"\n",
    "s3_dataroot = \"s3://dejiao-experiment-east1/datasets/psc_shorttext/\"\n",
    "s3_resdir = \"s3://dejiao-experiment-east1/train/SCCL-SBERT/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afraid-ranch",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "submit: 1\n",
      "distilroberta \t lr: 1e-05\n",
      "submit: 2\n",
      "distilbert \t lr: 1e-05\n",
      "submit: 3\n",
      "distilroberta \t lr: 1e-05\n",
      "submit: 4\n",
      "distilbert \t lr: 1e-05\n",
      "submit: 5\n",
      "distilroberta \t lr: 1e-05\n",
      "submit: 6\n",
      "distilbert \t lr: 1e-05\n",
      "submit: 7\n",
      "distilroberta \t lr: 1e-05\n",
      "submit: 8\n",
      "distilbert \t lr: 1e-05\n",
      "submit: 9\n",
      "distilroberta \t lr: 1e-05\n",
      "submit: 10\n",
      "distilbert \t lr: 1e-05\n",
      "submit: 11\n",
      "distilroberta \t lr: 1e-05\n",
      "submit: 12\n",
      "distilbert \t lr: 1e-05\n",
      "submit: 13\n",
      "distilroberta \t lr: 1e-05\n",
      "submit: 14\n",
      "distilbert \t lr: 1e-05\n",
      "submit: 15\n",
      "distilroberta \t lr: 1e-05\n",
      "submit: 16\n",
      "distilbert \t lr: 1e-05\n",
      "submit: 17\n",
      "distilroberta \t lr: 1e-05\n",
      "submit: 18\n",
      "distilbert \t lr: 1e-05\n",
      "submit: 19\n",
      "distilroberta \t lr: 1e-05\n",
      "submit: 20\n",
      "distilbert \t lr: 1e-05\n",
      "submit: 21\n",
      "distilroberta \t lr: 1e-05\n",
      "submit: 22\n",
      "distilbert \t lr: 1e-05\n",
      "submit: 23\n",
      "distilroberta \t lr: 1e-05\n",
      "submit: 24\n",
      "distilbert \t lr: 1e-05\n",
      "submit: 25\n",
      "distilroberta \t lr: 1e-05\n",
      "submit: 26\n",
      "distilbert \t lr: 1e-05\n",
      "submit: 27\n",
      "distilroberta \t lr: 1e-05\n",
      "submit: 28\n",
      "distilbert \t lr: 1e-05\n",
      "submit: 29\n",
      "distilroberta \t lr: 1e-05\n",
      "submit: 30\n",
      "distilbert \t lr: 1e-05\n",
      "submit: 31\n",
      "distilroberta \t lr: 1e-05\n",
      "submit: 32\n",
      "distilbert \t lr: 1e-05\n",
      "submit: 33\n",
      "distilroberta \t lr: 1e-05\n",
      "submit: 34\n",
      "distilbert \t lr: 1e-05\n",
      "submit: 35\n",
      "distilroberta \t lr: 1e-05\n",
      "submit: 36\n",
      "distilbert \t lr: 1e-05\n",
      "submit: 37\n",
      "distilroberta \t lr: 1e-05\n",
      "submit: 38\n",
      "distilbert \t lr: 1e-05\n",
      "submit: 39\n",
      "distilroberta \t lr: 1e-05\n",
      "submit: 40\n",
      "distilbert \t lr: 1e-05\n",
      "submit: 41\n",
      "distilroberta \t lr: 1e-05\n",
      "submit: 42\n",
      "distilbert \t lr: 1e-05\n",
      "submit: 43\n",
      "distilroberta \t lr: 1e-05\n",
      "submit: 44\n",
      "distilbert \t lr: 1e-05\n",
      "submit: 45\n",
      "distilroberta \t lr: 1e-05\n",
      "submit: 46\n",
      "distilbert \t lr: 1e-05\n",
      "submit: 47\n",
      "distilroberta \t lr: 1e-05\n",
      "submit: 48\n",
      "distilbert \t lr: 1e-05\n",
      "submit: 49\n",
      "distilroberta \t lr: 1e-05\n",
      "submit: 50\n",
      "distilbert \t lr: 1e-05\n",
      "submit: 51\n",
      "distilroberta \t lr: 1e-05\n",
      "submit: 52\n",
      "distilbert \t lr: 1e-05\n",
      "submit: 53\n",
      "distilroberta \t lr: 1e-05\n",
      "submit: 54\n",
      "distilbert \t lr: 1e-05\n"
     ]
    }
   ],
   "source": [
    "idx = 1  \n",
    "for lr, lr_scale in lr_params:\n",
    "    for temperature in temps:\n",
    "        wait_till_all_done(base_job_name)\n",
    "        \n",
    "        for datakey in datasets:    \n",
    "            for objective in objectives:\n",
    "                for ctype in contrast_types:\n",
    "                    for bert in bert_models:\n",
    "\n",
    "                        dataname, num_classes, text, label = CLUSTER_DATASETS[datakey]\n",
    "\n",
    "                        hyperparameters = {\n",
    "                            'train_instance': \"sagemaker\",\n",
    "                            'use_pretrain': use_pretrain,\n",
    "                            'datapath': s3_dataroot,\n",
    "                            'dataname': dataname, \n",
    "                            'text': text,\n",
    "                            'label': label,\n",
    "                            'num_classes': num_classes,\n",
    "                            'bert': bert,\n",
    "                            'objective': objective,\n",
    "                            'eta': eta, \n",
    "                            'augtype': 'virtual',\n",
    "                            'contrast_type': ctype,\n",
    "                            'lr': lr,\n",
    "                            'lr_scale': lr_scale,\n",
    "                            'lr_scale_contrast': '100',\n",
    "                            'batch_size': batch_size,\n",
    "                            'max_length': maxlen,\n",
    "                            'temperature': temperature,\n",
    "                            'max_iter': maxiter,\n",
    "                            'print_freq': '50',\n",
    "                            'seed': '0',\n",
    "                            'gpuid': '0',\n",
    "                            'resdir': '/tmp/resnli/PaperTempRes/',\n",
    "                            's3_resdir': s3_resdir,\n",
    "                        }\n",
    "\n",
    "                        try:\n",
    "                            estimator = PyTorch(entry_point='main.py',\n",
    "                                                source_dir='/home/ec2-user/efs/dejiao-explore/code/SCCL/',\n",
    "                                                role=role,\n",
    "                                                instance_count=1,\n",
    "                                                instance_type='ml.p3.2xlarge',\n",
    "                                                image_uri='157264205850.dkr.ecr.us-east-1.amazonaws.com/vncl-transformers-p17',\n",
    "                                                base_job_name = base_job_name,\n",
    "                                                hyperparameters=hyperparameters,\n",
    "                                                output_path='s3://dejiao-sagemaker-east1/SCCL/',\n",
    "                                                framework_version='1.8.1',\n",
    "                                                py_version = 'py3',\n",
    "                                                debugger_hook_config=False,\n",
    "                                                max_run=3 * 24 * 60 * 60,\n",
    "                                                volume_size = 500,\n",
    "                                                )\n",
    "\n",
    "                            estimator.fit(wait=False)\n",
    "                            print(\"submit: {}\".format(idx))\n",
    "                        except:\n",
    "                            print(\"submit: {} failed\".format(idx))\n",
    "\n",
    "                        time.sleep(2)\n",
    "                        idx += 1\n",
    "\n",
    "                        print(bert, \"\\t lr:\", lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "accepting-insight",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Environment (conda_pytorch_latest_p37)",
   "language": "python",
   "name": "conda_pytorch_latest_p37"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
